{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:42.157792Z",
     "iopub.status.busy": "2021-08-12T16:24:42.157239Z",
     "iopub.status.idle": "2021-08-12T16:24:42.533353Z",
     "shell.execute_reply": "2021-08-12T16:24:42.532848Z",
     "shell.execute_reply.started": "2021-08-12T16:24:42.157741Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import cv2\n",
    "import glob\n",
    "\n",
    "%pylab inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:43.081830Z",
     "iopub.status.busy": "2021-08-12T16:24:43.081287Z",
     "iopub.status.idle": "2021-08-12T16:24:43.571150Z",
     "shell.execute_reply": "2021-08-12T16:24:43.570091Z",
     "shell.execute_reply.started": "2021-08-12T16:24:43.081781Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os, sys, codecs, glob\n",
    "from PIL import Image, ImageDraw\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import cv2\n",
    "\n",
    "import torch\n",
    "torch.backends.cudnn.benchmark = False\n",
    "# torch.backends.cudnn.enabled = False\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:43.572257Z",
     "iopub.status.busy": "2021-08-12T16:24:43.572074Z",
     "iopub.status.idle": "2021-08-12T16:24:43.576630Z",
     "shell.execute_reply": "2021-08-12T16:24:43.575869Z",
     "shell.execute_reply.started": "2021-08-12T16:24:43.572241Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiDataset(Dataset):\n",
    "    def __init__(self, img_path, img_group, transform):\n",
    "        self.img_path = img_path\n",
    "        self.transform = transform\n",
    "        self.group = img_group\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "        \n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        \n",
    "        return img, self.group[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:43.737602Z",
     "iopub.status.busy": "2021-08-12T16:24:43.737057Z",
     "iopub.status.idle": "2021-08-12T16:24:43.768302Z",
     "shell.execute_reply": "2021-08-12T16:24:43.767722Z",
     "shell.execute_reply.started": "2021-08-12T16:24:43.737555Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>label</th>\n",
       "      <th>path</th>\n",
       "      <th>group</th>\n",
       "      <th>fold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>008233.jpg</td>\n",
       "      <td>008233.jpg 006688.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/008233.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>006688.jpg</td>\n",
       "      <td>008233.jpg 006688.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/006688.jpg</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>000232.jpg</td>\n",
       "      <td>000232.jpg 003552.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/000232.jpg</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>003552.jpg</td>\n",
       "      <td>000232.jpg 003552.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/003552.jpg</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>000814.jpg</td>\n",
       "      <td>000814.jpg 013765.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/000814.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>013765.jpg</td>\n",
       "      <td>000814.jpg 013765.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/013765.jpg</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>001429.jpg</td>\n",
       "      <td>001429.jpg 014834.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/001429.jpg</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>014834.jpg</td>\n",
       "      <td>001429.jpg 014834.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/014834.jpg</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>012795.jpg</td>\n",
       "      <td>012795.jpg 015860.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/012795.jpg</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>015860.jpg</td>\n",
       "      <td>012795.jpg 015860.jpg</td>\n",
       "      <td>./电商图像检索_数据集/train/015860.jpg</td>\n",
       "      <td>4</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         name                  label                           path  group  \\\n",
       "0  008233.jpg  008233.jpg 006688.jpg  ./电商图像检索_数据集/train/008233.jpg      0   \n",
       "1  006688.jpg  008233.jpg 006688.jpg  ./电商图像检索_数据集/train/006688.jpg      0   \n",
       "2  000232.jpg  000232.jpg 003552.jpg  ./电商图像检索_数据集/train/000232.jpg      1   \n",
       "3  003552.jpg  000232.jpg 003552.jpg  ./电商图像检索_数据集/train/003552.jpg      1   \n",
       "4  000814.jpg  000814.jpg 013765.jpg  ./电商图像检索_数据集/train/000814.jpg      2   \n",
       "5  013765.jpg  000814.jpg 013765.jpg  ./电商图像检索_数据集/train/013765.jpg      2   \n",
       "6  001429.jpg  001429.jpg 014834.jpg  ./电商图像检索_数据集/train/001429.jpg      3   \n",
       "7  014834.jpg  001429.jpg 014834.jpg  ./电商图像检索_数据集/train/014834.jpg      3   \n",
       "8  012795.jpg  012795.jpg 015860.jpg  ./电商图像检索_数据集/train/012795.jpg      4   \n",
       "9  015860.jpg  012795.jpg 015860.jpg  ./电商图像检索_数据集/train/015860.jpg      4   \n",
       "\n",
       "   fold  \n",
       "0     0  \n",
       "1     0  \n",
       "2     1  \n",
       "3     1  \n",
       "4     2  \n",
       "5     2  \n",
       "6     3  \n",
       "7     3  \n",
       "8     4  \n",
       "9     4  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df = pd.read_csv('./电商图像检索_数据集/train.csv')\n",
    "train_df['path'] = './电商图像检索_数据集/train/' + train_df['name']\n",
    "\n",
    "train_df['group'] = pd.factorize(train_df['label'])[0]\n",
    "train_df['fold'] = train_df['group'] % 5\n",
    "\n",
    "train_df = train_df.sort_values(by='group')\n",
    "train_df.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:44.097904Z",
     "iopub.status.busy": "2021-08-12T16:24:44.097355Z",
     "iopub.status.idle": "2021-08-12T16:24:44.109380Z",
     "shell.execute_reply": "2021-08-12T16:24:44.108621Z",
     "shell.execute_reply.started": "2021-08-12T16:24:44.097857Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "tr_path = train_df[train_df['fold'] != 0]['path'].values\n",
    "tr_label = train_df[train_df['fold'] != 0]['group']\n",
    "tr_label = pd.factorize(tr_label)[0]\n",
    "\n",
    "val_path = train_df[train_df['fold'] == 0]['path'].values\n",
    "val_label = train_df[train_df['fold'] == 0]['group'].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:44.656698Z",
     "iopub.status.busy": "2021-08-12T16:24:44.656157Z",
     "iopub.status.idle": "2021-08-12T16:24:44.667303Z",
     "shell.execute_reply": "2021-08-12T16:24:44.666708Z",
     "shell.execute_reply.started": "2021-08-12T16:24:44.656670Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class ArcModule(nn.Module):\n",
    "    def __init__(self, in_features, out_features, s = 10, m = 0.1):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.s = s\n",
    "        self.m = m\n",
    "        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))\n",
    "        nn.init.xavier_normal_(self.weight)\n",
    "\n",
    "        self.cos_m = math.cos(m)\n",
    "        self.sin_m = math.sin(m)\n",
    "        self.th = torch.tensor(math.cos(math.pi - m))\n",
    "        self.mm = torch.tensor(math.sin(math.pi - m) * m)\n",
    "\n",
    "    def forward(self, inputs, labels):\n",
    "        cos_th = F.linear(inputs, F.normalize(self.weight))\n",
    "        cos_th = cos_th.clamp(-1, 1)\n",
    "        sin_th = torch.sqrt(1.0 - torch.pow(cos_th, 2))\n",
    "        cos_th_m = cos_th * self.cos_m - sin_th * self.sin_m\n",
    "        # print(type(cos_th), type(self.th), type(cos_th_m), type(self.mm))\n",
    "        cos_th_m = torch.where(cos_th > self.th, cos_th_m, cos_th - self.mm)\n",
    "        \n",
    "        cond_v = cos_th - self.th\n",
    "        cond = cond_v <= 0\n",
    "        cos_th_m[cond] = (cos_th - self.mm)[cond]\n",
    "\n",
    "        if labels.dim() == 1:\n",
    "            labels = labels.unsqueeze(-1)\n",
    "        onehot = torch.zeros(cos_th.size()).cuda()\n",
    "        labels = labels.type(torch.LongTensor).cuda()\n",
    "        onehot.scatter_(1, labels, 1.0)\n",
    "        outputs = onehot * cos_th_m + (1.0 - onehot) * cos_th\n",
    "        outputs = outputs * self.s\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:45.069950Z",
     "iopub.status.busy": "2021-08-12T16:24:45.069410Z",
     "iopub.status.idle": "2021-08-12T16:24:45.460435Z",
     "shell.execute_reply": "2021-08-12T16:24:45.459987Z",
     "shell.execute_reply.started": "2021-08-12T16:24:45.069905Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=1536, out_features=137, bias=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import timm\n",
    "timm.create_model('efficientnet_b3', num_classes=137, \n",
    "                          pretrained=True).classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:45.461355Z",
     "iopub.status.busy": "2021-08-12T16:24:45.461185Z",
     "iopub.status.idle": "2021-08-12T16:24:45.465092Z",
     "shell.execute_reply": "2021-08-12T16:24:45.464693Z",
     "shell.execute_reply.started": "2021-08-12T16:24:45.461338Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import timm\n",
    "\n",
    "class XunFeiNet(nn.Module):\n",
    "    def __init__(self, nclass):\n",
    "        super(XunFeiNet, self).__init__()\n",
    "                \n",
    "        model = timm.create_model('efficientnet_b3', num_classes=137, \n",
    "                          pretrained=True)\n",
    "        model.classifier = torch.nn.Identity()\n",
    "        self.model = model\n",
    "        self.margin = ArcModule(in_features=1536, \n",
    "                                out_features = nclass)\n",
    "        \n",
    "    def forward(self, img, labels=None):        \n",
    "        feat = self.model(img)\n",
    "        \n",
    "        feat = F.normalize(feat)\n",
    "        if labels is not None:\n",
    "            return self.margin(feat, labels)\n",
    "        return feat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:45.973575Z",
     "iopub.status.busy": "2021-08-12T16:24:45.973124Z",
     "iopub.status.idle": "2021-08-12T16:24:45.982535Z",
     "shell.execute_reply": "2021-08-12T16:24:45.981497Z",
     "shell.execute_reply.started": "2021-08-12T16:24:45.973542Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    model.train()\n",
    "\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True)\n",
    "        target = target.cuda(non_blocking=True)\n",
    "\n",
    "        output = model(input, target)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "            \n",
    "def validate(val_loader, model):\n",
    "    model.eval()\n",
    "    \n",
    "    val_feats = []\n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in enumerate(val_loader):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            val_feats.append(output.data.cpu().numpy())\n",
    "    return val_feats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:24:46.512256Z",
     "iopub.status.busy": "2021-08-12T16:24:46.511707Z",
     "iopub.status.idle": "2021-08-12T16:24:46.518591Z",
     "shell.execute_reply": "2021-08-12T16:24:46.517932Z",
     "shell.execute_reply.started": "2021-08-12T16:24:46.512209Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def set_iou(label, predict):\n",
    "    interset = set(label.split()) &  set(predict.split())\n",
    "    unionset = set(label.split()) | set(predict.split())\n",
    "    return len(interset) *1.0 / len(unionset) *1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-12T16:30:13.308268Z",
     "iopub.status.busy": "2021-08-12T16:30:13.307601Z",
     "iopub.status.idle": "2021-08-12T17:38:20.426774Z",
     "shell.execute_reply": "2021-08-12T17:38:20.425979Z",
     "shell.execute_reply.started": "2021-08-12T16:30:13.308218Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FOLD 0, Face Count 1805\n",
      "Epoch:  0\n",
      "Val 0.8352631578947368 0.5320271540529868\n",
      "Epoch:  1\n",
      "Val 0.7836842105263158 0.57745329110047\n",
      "Epoch:  2\n",
      "Val 0.7321052631578947 0.6243946567204881\n",
      "Epoch:  3\n",
      "Val 0.7063157894736842 0.6344827049794179\n",
      "Epoch:  4\n",
      "Val 0.6289473684210526 0.6595472248796937\n",
      "Epoch:  5\n",
      "Val 0.6289473684210526 0.6680417208030003\n",
      "Epoch:  6\n",
      "Val 0.5773684210526315 0.6787096838992308\n",
      "Epoch:  7\n",
      "Val 0.6289473684210526 0.6733302010324445\n",
      "FOLD 1, Face Count 1805\n",
      "Epoch:  0\n",
      "Val 0.8610526315789473 0.5498404909349248\n",
      "Epoch:  1\n",
      "Val 0.7836842105263158 0.6100791957694144\n",
      "Epoch:  2\n",
      "Val 0.7836842105263158 0.6501454066648876\n",
      "Epoch:  3\n",
      "Val 0.7063157894736842 0.6730811399264057\n",
      "Epoch:  4\n",
      "Val 0.7321052631578947 0.6847359816223564\n",
      "Epoch:  5\n",
      "Val 0.7063157894736842 0.7002876540219989\n",
      "Epoch:  6\n",
      "Val 0.6547368421052632 0.7051556184050793\n",
      "Epoch:  7\n",
      "Val 0.6289473684210526 0.7175199781825079\n",
      "FOLD 2, Face Count 1805\n",
      "Epoch:  0\n",
      "Val 0.8352631578947368 0.5416984446484325\n",
      "Epoch:  1\n",
      "Val 0.8094736842105263 0.5813628438109304\n",
      "Epoch:  2\n",
      "Val 0.7578947368421052 0.6074731984896699\n",
      "Epoch:  3\n",
      "Val 0.7063157894736842 0.6435633617507168\n",
      "Epoch:  4\n",
      "Val 0.6805263157894736 0.6498491220545971\n",
      "Epoch:  5\n",
      "Val 0.6289473684210526 0.677508359227193\n",
      "Epoch:  6\n",
      "Val 0.9126315789473685 0.5481342601060911\n",
      "Epoch:  1\n",
      "Val 0.7578947368421052 0.6168890854381741\n",
      "Epoch:  2\n",
      "Val 0.7321052631578947 0.6636471857129135\n",
      "Epoch:  3\n",
      "Val 0.7063157894736842 0.6925897676484531\n",
      "Epoch:  4\n",
      "Val 0.6547368421052632 0.6955295454415172\n",
      "Epoch:  5\n",
      "Val 0.6547368421052632 0.7052625712836981\n",
      "Epoch:  6\n",
      "Val 0.6289473684210526 0.7085951455082911\n",
      "Epoch:  7\n",
      "Val 0.6547368421052632 0.7190326909223622\n",
      "FOLD 4, Face Count 1806\n",
      "Epoch:  0\n",
      "Val 0.8610526315789473 0.5232231549467558\n",
      "Epoch:  1\n",
      "Val 0.8094736842105263 0.5969162581213321\n",
      "Epoch:  2\n",
      "Val 0.7578947368421052 0.6287194518779029\n",
      "Epoch:  3\n",
      "Val 0.7063157894736842 0.6484292045438197\n",
      "Epoch:  4\n",
      "Val 0.6805263157894736 0.6621014839830908\n",
      "Epoch:  5\n",
      "Val 0.6805263157894736 0.6791056190788811\n",
      "Epoch:  6\n",
      "Val 0.6289473684210526 0.6761130898540981\n",
      "Epoch:  7\n",
      "Val 0.6289473684210526 0.6948379550916549\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "for fold in range(5):\n",
    "    tr_path = train_df[train_df['fold'] != fold]['path'].values\n",
    "    tr_label = train_df[train_df['fold'] != fold]['group']\n",
    "    tr_label = pd.factorize(tr_label)[0]\n",
    "\n",
    "    val_path = train_df[train_df['fold'] == fold]['path'].values\n",
    "    val_label = train_df[train_df['fold'] == fold]['group'].values\n",
    "\n",
    "    print(f'FOLD {fold}, Face Count {tr_label.max()}')\n",
    "    \n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        XunFeiDataset(tr_path, tr_label,\n",
    "                            transforms.Compose([\n",
    "                            transforms.Resize((300, 300)),\n",
    "                            transforms.RandomHorizontalFlip(),\n",
    "                            transforms.RandomVerticalFlip(),\n",
    "                            transforms.ToTensor(),\n",
    "                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "            ])\n",
    "        ),\n",
    "        batch_size=10, shuffle=True, num_workers=5,\n",
    "    )\n",
    "\n",
    "    val_loader = torch.utils.data.DataLoader(\n",
    "        XunFeiDataset(val_path, val_label,\n",
    "                            transforms.Compose([\n",
    "                            transforms.Resize((300, 300)),\n",
    "                            transforms.ToTensor(),\n",
    "                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "            ])\n",
    "        ),\n",
    "        batch_size=10, shuffle=False, num_workers=5,\n",
    "    )\n",
    "\n",
    "    model = XunFeiNet(tr_label.max()+1).cuda()\n",
    "    criterion = nn.CrossEntropyLoss().cuda()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), 0.0003)\n",
    "    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.85)\n",
    "    fold_f1 = 0.0\n",
    "    \n",
    "    for epoch in range(8):\n",
    "        print('Epoch: ', epoch)\n",
    "\n",
    "        train(train_loader, model, criterion, optimizer, epoch)\n",
    "        scheduler.step()\n",
    "\n",
    "        val_feats = validate(val_loader, model)\n",
    "        val_feats = np.vstack(val_feats)\n",
    "        # val_feats = normalize(val_feats)\n",
    "\n",
    "        val_distance = []\n",
    "        for feat in val_feats:\n",
    "            dis = np.dot(feat, val_feats.T)\n",
    "            val_distance.append(dis)\n",
    "\n",
    "        best_threahold, best_f1 = 0, 0\n",
    "        for threahold in np.linspace(0.5, 0.99, 20):\n",
    "            val_submit = []\n",
    "            for dis in val_distance[:]:\n",
    "                pred = np.where(dis > threahold)[0]\n",
    "                if len(pred) == 1:\n",
    "                    ids = dis.argsort()[::-1]\n",
    "                    pred = [x for x in ids[dis[ids] > 0.1]][:2]\n",
    "\n",
    "                val_submit.append(pred)\n",
    "\n",
    "            val_f1s = []\n",
    "            for x, pred in zip(val_label, val_submit):\n",
    "                label = np.where(val_label == x)[0]\n",
    "                val_f1 = len(set(pred) & set(label)) / len(set(pred) | set(label)) \n",
    "                val_f1s.append(val_f1)\n",
    "\n",
    "            if best_f1 < np.mean(val_f1s):\n",
    "                best_f1 = np.mean(val_f1s)\n",
    "                best_threahold = threahold\n",
    "\n",
    "        if fold_f1 < best_f1:\n",
    "            torch.save(model.state_dict(), 'model_{0}.pt'.format(fold))\n",
    "            fold_f1 = best_f1\n",
    "        print('Val', best_threahold, best_f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:33:41.657927Z",
     "iopub.status.busy": "2021-08-13T00:33:41.657319Z",
     "iopub.status.idle": "2021-08-13T00:33:41.677526Z",
     "shell.execute_reply": "2021-08-13T00:33:41.677044Z",
     "shell.execute_reply.started": "2021-08-13T00:33:41.657878Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_path = glob.glob('./电商图像检索_数据集/test/*')\n",
    "test_path.sort()\n",
    "test_path = np.array(test_path)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(test_path, [0]*len(test_path),\n",
    "                        transforms.Compose([\n",
    "                        transforms.Resize((300, 300)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ),\n",
    "    batch_size=50, shuffle=False, num_workers=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:37:22.102963Z",
     "iopub.status.busy": "2021-08-13T00:37:22.102377Z",
     "iopub.status.idle": "2021-08-13T00:38:52.567218Z",
     "shell.execute_reply": "2021-08-13T00:38:52.566444Z",
     "shell.execute_reply.started": "2021-08-13T00:37:22.102915Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_feats_fold = []\n",
    "for fold, path in zip(range(5),\n",
    "    ['model_0.pt', 'model_1.pt', 'model_2.pt', 'model_3.pt', 'model_4.pt']):\n",
    "    tr_path = train_df[train_df['fold'] != fold]['path'].values\n",
    "    tr_label = train_df[train_df['fold'] != fold]['group']\n",
    "    tr_label = pd.factorize(tr_label)[0]\n",
    "    model = XunFeiNet(tr_label.max()+1).cuda()\n",
    "    \n",
    "    model.load_state_dict(torch.load(path))\n",
    "    model.eval()\n",
    "    test_feats = []\n",
    "    with torch.no_grad():\n",
    "        for data in test_loader:\n",
    "            data = data[0].cuda()\n",
    "            feat = model(data)\n",
    "            test_feats.append(feat.data.cpu().numpy())\n",
    "\n",
    "    test_feats = np.vstack(test_feats)\n",
    "    test_feats = normalize(test_feats)\n",
    "    test_feats_fold.append(test_feats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:39:45.751002Z",
     "iopub.status.busy": "2021-08-13T00:39:45.750425Z",
     "iopub.status.idle": "2021-08-13T00:39:45.846367Z",
     "shell.execute_reply": "2021-08-13T00:39:45.845441Z",
     "shell.execute_reply.started": "2021-08-13T00:39:45.750953Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_feats = np.stack(test_feats_fold).mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:46:18.345740Z",
     "iopub.status.busy": "2021-08-13T00:46:18.345172Z",
     "iopub.status.idle": "2021-08-13T00:46:21.768345Z",
     "shell.execute_reply": "2021-08-13T00:46:21.767551Z",
     "shell.execute_reply.started": "2021-08-13T00:46:18.345692Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_submit = []\n",
    "for path, feat in zip(test_path[:], test_feats[:]):\n",
    "    dis = np.dot(feat, test_feats.T)\n",
    "    pred = [x.split('/')[-1] for x in test_path[np.where(dis > 0.3)[0]]]\n",
    "    if len(pred) <= 1:\n",
    "        ids = dis.argsort()[::-1]\n",
    "        pred = [x.split('/')[-1] for x in test_path[ids[:2]]]\n",
    "    \n",
    "    test_submit.append([\n",
    "        path.split('/')[-1],\n",
    "        pred\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:46:21.770171Z",
     "iopub.status.busy": "2021-08-13T00:46:21.769649Z",
     "iopub.status.idle": "2021-08-13T00:46:21.774898Z",
     "shell.execute_reply": "2021-08-13T00:46:21.774311Z",
     "shell.execute_reply.started": "2021-08-13T00:46:21.770138Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.35748896"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dis.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:46:21.776539Z",
     "iopub.status.busy": "2021-08-13T00:46:21.776069Z",
     "iopub.status.idle": "2021-08-13T00:46:22.224568Z",
     "shell.execute_reply": "2021-08-13T00:46:22.223262Z",
     "shell.execute_reply.started": "2021-08-13T00:46:21.776509Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_submit = pd.DataFrame(test_submit, columns=['name', 'label'])\n",
    "test_submit['label'] = test_submit['label'].apply(lambda x: ' '.join(x))\n",
    "test_submit\n",
    "test_submit.to_csv('submit.csv',index=None)\n",
    "\n",
    "# 提交线上0.580"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:44:58.308778Z",
     "iopub.status.busy": "2021-08-13T00:44:58.308171Z",
     "iopub.status.idle": "2021-08-13T00:45:08.785312Z",
     "shell.execute_reply": "2021-08-13T00:45:08.784523Z",
     "shell.execute_reply.started": "2021-08-13T00:44:58.308730Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 做了扩展查询\n",
    "test_submit = []\n",
    "for path, feat in zip(test_path[:], test_feats[:]):\n",
    "    dis = np.dot(feat, test_feats.T)\n",
    "\n",
    "    feat_qe = np.multiply(dis[np.argsort(dis)[::-1][:2]].reshape(2, -1),\n",
    "            test_feats[np.argsort(dis)[::-1][:2]]).mean(0)\n",
    "    dis = np.dot(feat_qe, test_feats.T)\n",
    "    \n",
    "    pred = [x.split('/')[-1] for x in test_path[np.where(dis > 0.1)[0]]]\n",
    "    if len(pred) <= 1:\n",
    "        ids = dis.argsort()[::-1]\n",
    "        pred = [x.split('/')[-1] for x in test_path[ids[:2]]]\n",
    "    \n",
    "    test_submit.append([\n",
    "        path.split('/')[-1],\n",
    "        pred\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:45:09.534870Z",
     "iopub.status.busy": "2021-08-13T00:45:09.534291Z",
     "iopub.status.idle": "2021-08-13T00:45:09.552956Z",
     "shell.execute_reply": "2021-08-13T00:45:09.552362Z",
     "shell.execute_reply.started": "2021-08-13T00:45:09.534822Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_submit = pd.DataFrame(test_submit, columns=['name', 'label'])\n",
    "test_submit['label'] = test_submit['label'].apply(lambda x: ' '.join(x))\n",
    "test_submit\n",
    "test_submit.to_csv('submit_qe.csv',index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2021-08-13T00:44:42.946410Z",
     "iopub.status.busy": "2021-08-13T00:44:42.945839Z",
     "iopub.status.idle": "2021-08-13T00:44:42.953325Z",
     "shell.execute_reply": "2021-08-13T00:44:42.952231Z",
     "shell.execute_reply.started": "2021-08-13T00:44:42.946363Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.12779833"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dis.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
